import os
import jax.numpy as np
import numpy as onp
import pandas as pd
import jax
import jax.random as npr
from tqdm import tqdm

import scalevi.utils as utils

######################################################################

################# Generate Genome data
######################################################################
def generate_genome_data(dir_path = None, score_fname = None, ratings_fname = None, nb = False):
    """
    Build the dense genome matrix and the movie-id -> genome-row lookup
    from the raw CSV files.

    Args:
        dir_path: directory prefix of the CSV files. When None it defaults to
            "../data/datasets/25m" if `nb` is truthy, else "data/datasets/25m".
        score_fname: path of the genome-scores CSV, appended to `dir_path`
            (default "/ml-25m/genome-scores.csv").
        ratings_fname: path of the ratings CSV, appended to `dir_path`
            (default "/ml-25m/ratings.csv").
        nb: selects the default `dir_path` (see above).

    Return:
    genome matrix : M X S, where each ith row is the genome score for the ith movie
    movies_with_gscores_tracker: A index map of 25M dataset;
                                            [
                                                if movie does not have genome scores : nan
                                                otherwise : indice of the movie in the genome matrix
                                            ]
    """
    if dir_path is None:
        dir_path = "../data/datasets/25m" if nb else "data/datasets/25m"
    score_fname = "/ml-25m/genome-scores.csv" if score_fname is None else score_fname
    ratings_fname = "/ml-25m/ratings.csv" if ratings_fname is None else ratings_fname

    g_cols = ["movie_id", "tag_id", "relevance"]
    df = pd.read_csv(dir_path + score_fname, names=g_cols, sep=",", skiprows=1)
    # ids in the CSV are 1-based; shift them to 0-based
    movies_w_scores = onp.asarray(df.movie_id) - 1
    tags = onp.asarray(df.tag_id) - 1
    rel = np.array(df.relevance)
    del df

    # Only the largest rated movie id is needed, so read just that column.
    r_cols = ["user_id", "movie_id", "rating", "unix_timestamp"]
    df = pd.read_csv(
        dir_path + ratings_fname, names=r_cols, sep=",", skiprows=1,
        usecols=["movie_id"],
    )
    max_rated_movie = int(df.movie_id.max()) - 1
    del df

    # onp.unique already returns the ids sorted, giving a compact 0..n-1 range
    movies = onp.unique(movies_w_scores)
    n_movies = len(movies)
    n_tags = int(tags.max()) + 1

    # tracker initialized to nans; movies with gscores get their genome row index.
    # `.at[...].set(...)` replaces the removed `jax.ops.index_update`.
    tracker = np.full(max_rated_movie + 1, np.nan)
    movies_w_scores_map = tracker.at[movies].set(
        np.arange(n_movies, dtype=tracker.dtype)
    )

    # reshape the gscores to create a genome array;
    # row 0 corresponds to the first movie with gscore.
    # NOTE(review): this relies on the scores file listing every
    # (movie, tag) pair, grouped by movie and ordered by tag -- confirm.
    genome = np.reshape(rel, [n_movies, n_tags])
    return genome, movies_w_scores_map

def save_genome_data(genome, movies_w_scores, nb = False):
    """
    Write the genome matrix and the movie tracker into a single .npz archive.

    The archive path is `utils.get_genome_dir(nb=nb) + utils.get_genome_fname()`;
    the directory is created first when it does not exist. The arrays are
    stored under the keys "genome" and "movies_w_gscores".
    """
    target_dir = utils.get_genome_dir(nb = nb)
    target_name = utils.get_genome_fname()
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    payload = {"genome": genome, "movies_w_gscores": movies_w_scores}
    np.savez(target_dir + target_name, **payload)
def load_genome_data(nb = False):
    """
    Return full genome data and map for the movies with gscores.

    Loads the archive at `utils.get_genome_dir(nb=nb) + utils.get_genome_fname()`
    and returns its arrays as a list, in the archive's own key order.
    """
    archive = np.load(utils.get_genome_dir(nb = nb) + utils.get_genome_fname())
    return [archive[key] for key in archive.keys()]


######################################################################
################# Generate PCAed Genome data
######################################################################

def generate_PCAed_genome_data(dim, return_variance = False, nb = False):
    """
    Project the saved genome matrix onto `dim` principal components.

    The genome is read with `load_genome_data(nb=nb)`, the global NumPy RNG
    is seeded with 0, and `sklearn.decomposition.PCA(dim).fit_transform` is
    applied. Only the transformed matrix is returned.

    NOTE(review): `return_variance` is accepted but currently has no effect.
    """
    genome, _ = load_genome_data(nb = nb)
    onp.random.seed(0)
    from sklearn.decomposition import PCA
    reducer = PCA(dim)
    return reducer.fit_transform(genome)

def save_PCAed_genome_data(genomes, exp_vars=None, nb = False):
    """
    Save a PCA-reduced genome matrix (and optional explained variances).

    The file name is derived from the size of the last axis of `genomes` via
    `utils.get_PCAed_genome_fnames`; the directory comes from
    `utils.get_genome_dir(nb=nb)` and is passed to `utils.create_dir` first.
    Arrays are stored under the keys "genome" and "exp_var".
    """
    n_components = genomes.shape[-1]
    target_dir = utils.get_genome_dir(nb = nb)
    utils.create_dir(target_dir)
    target_name = utils.get_PCAed_genome_fnames(n_components)
    np.savez(target_dir + target_name, genome = genomes, exp_var = exp_vars)
        
        
def load_PCAed_genome_data(dim, return_exp_vars = False, nb = False):
    """
    Load the PCA-reduced genome saved for `dim` components.

    Returns the "genome" array, or the pair ("genome", "exp_var") when
    `return_exp_vars` is truthy.

    Raises:
        ValueError: if no archive exists for this `dim`.
    """
    full_path = utils.get_genome_dir(nb = nb) + utils.get_PCAed_genome_fnames(dim)

    if not os.path.exists(full_path):
        raise ValueError("No data saved for the selected value of the dimensions.")

    archive = np.load(full_path)
    if not return_exp_vars:
        return archive['genome']
    return (archive['genome'], archive['exp_var'])



######################################################################
################# Generate User data
######################################################################

def generate_users_data(dir_path = None, ratings_fname = None, nb = False):
    """
    Filter the raw ratings down to movies that have genome scores.

    Args:
        dir_path: directory prefix of the ratings CSV. When None it defaults
            to "../data/datasets/25m" if `nb` is truthy, else
            "data/datasets/25m".
        ratings_fname: path of the ratings CSV, appended to `dir_path`
            (default "/ml-25m/ratings.csv").
        nb: selects the default `dir_path` and is forwarded to
            `load_genome_data`.

    Returns:
        (users, movies, ratings, raw_ratings) restricted to rated movies that
        appear in the genome. `users` are 0-based user ids, `movies` are the
        genome row indices, `raw_ratings` are the original ratings and
        `ratings` is the boolean `raw_ratings >= 4.0`.
    """
    # Forward `nb` so the tracker is read from the same data root as the CSV.
    _, movies_with_gscores = load_genome_data(nb = nb)

    if dir_path is None:
        dir_path = "../data/datasets/25m" if nb else "data/datasets/25m"
    ratings_fname = "/ml-25m/ratings.csv" if ratings_fname is None else ratings_fname

    r_cols = ["user_id", "movie_id", "rating", "unix_timestamp"]
    df = pd.read_csv(
        dir_path + ratings_fname, names=r_cols, sep=",", skiprows=1
    )
    # ids in the CSV are 1-based; shift them to 0-based
    users = np.array(df.user_id) - 1
    movies_w_ratings = np.array(df.movie_id) - 1
    ratings = np.array(df.rating)

    del df

    # tracker entry per rating: genome row index, or nan when the movie has no gscores
    movies_w_ratings_and_scores = movies_with_gscores[movies_w_ratings]
    # only keep ratings for movies that are in the genome
    mapof_w_scores_and_ratings = ~np.isnan(movies_w_ratings_and_scores)
    users = users[mapof_w_scores_and_ratings]
    # need the new indices here for genome reference
    movies = movies_w_ratings_and_scores[mapof_w_scores_and_ratings].astype(np.int64)
    raw_ratings = ratings[mapof_w_scores_and_ratings]
    ratings = raw_ratings >= 4.0

    return (users, movies, ratings, raw_ratings)

def save_users_data(users, movies, ratings, raw_ratings, nb = False):
    """
    Store the filtered user data in one .npz archive.

    The archive lives at `utils.get_users_dir(nb=nb) + utils.get_users_fname()`
    (the directory is created when missing) and holds the keys "users",
    "movies", "ratings" and "raw_ratings", in that order.
    """
    target_dir = utils.get_users_dir(nb = nb)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    target_name = utils.get_users_fname()
    payload = {
        "users": users,
        "movies": movies,
        "ratings": ratings,
        "raw_ratings": raw_ratings,
    }
    np.savez(target_dir + target_name, **payload)
def load_users_data(nb = False):
    """
    Return full users data: users, movies, ratings, raw ratings.

    Delegates to `utils.load_npz_arrays` on the archive at
    `utils.get_users_dir(nb=nb) + utils.get_users_fname()`; the arrays are
    expected back in the same order that was used for saving.
    """
    archive_path = utils.get_users_dir(nb = nb) + utils.get_users_fname()
    return utils.load_npz_arrays(archive_path)

def test_saving_loading_user_data(nb = False):
    """
    Round-trip check: generated user data must equal what is read back
    after saving.

    `nb` is forwarded to every step so that generation, saving and loading
    all use the same data root.
    """
    orig_data = generate_users_data(nb = nb)
    save_users_data(*orig_data, nb = nb)
    loaded_data = load_users_data(nb = nb)
    assert utils.array_equal_list(orig_data, loaded_data)

def test_data_matches_with_Jutins(nb = False):
    """
    Compare the first 13 entries of the saved user data against a legacy archive.

    All ratings will not match, however, movies with ids less than 20699 will match.

    NOTE(review): the legacy archive path is hard-coded with a "../" prefix
    and does not depend on `nb` -- confirm this is intended.
    """
    head = slice(None, 13)
    users, movies, ratings, raw_ratings = load_users_data(nb = nb)
    legacy = utils.load_npz_arrays('../data/datasets/25m/ml-25m/ml-25m.npz')
    # the legacy archive stores raw ratings before the binarised ratings
    users_old, movies_old, raw_ratings_old, ratings_old, _ = legacy
    current_head = [arr[head] for arr in (users, movies, ratings, raw_ratings)]
    legacy_head = [arr[head] for arr in (users_old, movies_old, ratings_old, raw_ratings_old)]
    assert utils.array_equal_list(current_head, legacy_head)
 


######################################################################
################# Generate leM User data
######################################################################


def generate_leM_user_data(M, nb = False):
    """
    Keep only the users that have at most `M` ratings.

    Steps:
        - load the saved user data (users must already be sorted by id)
        - drop every rating belonging to a user with more than `M` ratings
        - build the compact user ids and the per-user index map through
          `generate_leM_metadata`

    Returns:
        (user ids, users index map, movie ids, ratings) for the kept ratings.
    """
    users, movies, ratings, _ = load_users_data(nb = nb)
    assert np.array_equal(users, np.sort(users))

    unique_parts = onp.unique(
        users, return_index = True, return_inverse = True, return_counts = True
    )
    _, first_pos, inverse, per_user = utils.np_arrays(unique_parts)
    # each user's ratings must form one contiguous run
    assert np.array_equal(np.cumsum(per_user)[:-1], first_pos[1:])

    print(f"# of different users: {len(per_user)}")
    print(f"# of maximum ratings for each user: {M}")

    # per-rating mask: True when the rating's user has <= M ratings
    keep = (per_user <= M)[inverse]
    users_leM, movies_leM, ratings_leM = users[keep], movies[keep], ratings[keep]

    assert np.array_equal(users_leM, np.sort(users_leM))
    print("Getting the metadata...")
    users_leM_ids, users_leM_indexmap = generate_leM_metadata(users_leM = users_leM, M = M)
    return users_leM_ids, users_leM_indexmap, movies_leM, ratings_leM


def generate_leM_metadata(users_leM, M):
    """
    Build compact user ids and a fixed-width index map for sorted user ids.

    Args:
        users_leM: user id per rating; each user's ratings must be contiguous
            (checked by assertion).
        M: width of the index map, i.e. the maximum ratings per user.

    Returns:
        (ids, index_map) where `ids[k]` is the compact 0-based user id of
        rating k, and `index_map[i]` lists the positions of user i's ratings
        in `users_leM`, right-padded with -1 up to length `M`.

    Raises:
        ValueError: if some user has more than `M` ratings.
    """
    _, users_leM_start_ind, users_leM_inv_map, users_leM_counts = utils.np_arrays(
        onp.unique(
            users_leM,
            return_index = True,
            return_inverse = True,
            return_counts = True,
        )
    )
    # each user's ratings must form one contiguous run
    assert(np.array_equal(np.cumsum(users_leM_counts)[:-1], users_leM_start_ind[1:]))

    if len(users_leM_counts) and int(np.max(users_leM_counts)) > M:
        raise ValueError(f"Some user has more than M = {M} ratings.")

    # Vectorised construction: row i is start_i + [0, 1, ..., M-1], with every
    # slot at or beyond the user's rating count replaced by -1. This avoids
    # one padding call per user.
    slot = np.arange(M)[None, :]
    first = np.asarray(users_leM_start_ind)[:, None]
    count = np.asarray(users_leM_counts)[:, None]
    users_leM_indexmap = np.where(slot < count, first + slot, -1)

    # compact ids must be non-decreasing and cover 0..n_users-1 exactly
    assert(np.array_equal(users_leM_inv_map, np.sort(users_leM_inv_map)))
    assert(np.array_equal(onp.unique(users_leM_inv_map), np.arange(len(users_leM_counts))))
    return users_leM_inv_map, users_leM_indexmap


def test_new_generate_data(users, movies, ratings):
    """
    Compare `generate_leM_user_data` against an older reference
    implementation (`func_new` below) for M = 10.

    NOTE(review): `generate_leM_user_data` is defined in this file as
    `(M, nb=False)`, so the call below with three positional arrays plus
    `M = M` does not match that signature -- this check looks stale.
    NOTE(review): the reference pads rows to the largest per-user count,
    while `generate_leM_metadata` pads to M; the index maps can only agree
    when those two widths coincide -- confirm.
    """
    M = 10
    users_leM_ids, users_leM_indexmap, movies_leM, ratings_leM = generate_leM_user_data(users, movies, ratings, M = M)    
    
    # reference filtering: per-rating mask of users with at most M ratings
    uniq_users, invers_ind, users_counts = utils.np_arrays(onp.unique(users, return_inverse = True, return_counts = True))
    map_of_users_le_M_old = (users_counts <=M)[invers_ind]

    users_ = users[map_of_users_le_M_old]
    movies_ = movies[map_of_users_le_M_old]
    ratings_ = ratings[map_of_users_le_M_old]

    # plt.plot()
    # plt.show()

    def func_new(true_users, cutoff):
        # Reference metadata builder: for every unique user, collect the
        # positions of its ratings and right-pad with -1 to the largest count.
        users= np.array(true_users[:cutoff])
        uniq_users, users_inv_index, users_counts = utils.np_arrays(onp.unique(users, return_counts = True, return_inverse = True))
        max_len = np.max(users_counts)
        users_meta_arr = np.array([np.pad( np.nonzero(users == u)[0],
                                          (0, max_len - users_counts[i]), 
                                          'constant', 
                                          constant_values=(-1)
                                          )
                                    for i, u in enumerate(np.array(uniq_users))
                                  ])    
        return users_inv_index, users_meta_arr


    # cutoff None keeps every filtered rating
    ids, index_map = func_new(users_, None)

    assert(np.array_equal(ids, users_leM_ids))
    assert(np.array_equal(index_map, users_leM_indexmap))
    assert(np.array_equal(movies_, movies_leM))
    assert(np.array_equal(ratings_, ratings_leM))


def save_leM_user_data(ids, index_map, movies, ratings, M, nb = False):
    """
    Save the at-most-M-ratings user data to one .npz archive.

    The archive path is `utils.get_users_dir(nb=nb) + utils.get_leM_users_fname(M=M)`
    and the keys are "ids", "index_map", "movies", "ratings" and "M".
    """
    target_dir = utils.get_users_dir(nb = nb)
    target_name = utils.get_leM_users_fname(M = M)
    payload = {
        "ids": ids,
        "index_map": index_map,
        "movies": movies,
        "ratings": ratings,
        "M": M,
    }
    np.savez(target_dir + target_name, **payload)

def load_leM_user_data(M, nb = False):
    """
    Load the at-most-M-ratings user data via `utils.load_npz_arrays`.

    Reads the archive at `utils.get_users_dir(nb=nb) + utils.get_leM_users_fname(M=M)`.
    """
    archive_path = utils.get_users_dir(nb = nb) + utils.get_leM_users_fname(M = M)
    return utils.load_npz_arrays(archive_path)

def test_save_load_leM_user_data(users, movies, ratings):
    """
    Round-trip check for the at-most-M-ratings user data with M = 10.

    `generate_leM_user_data` takes `(M, nb=False)` and reads the user data
    from disk itself, so it is called with `M` only. The `users`, `movies`
    and `ratings` parameters are kept for backward compatibility and are
    not used.
    """
    M = 10
    generated = generate_leM_user_data(M = M)
    save_leM_user_data(*generated, M)

    # the archive also stores M as its last entry
    *loaded, M_ = load_leM_user_data(M)

    assert utils.array_equal_list(list(generated), loaded)
    assert M == M_



######################################################################
################# Generate leM N user data 
######################################################################

def generate_leM_N_user_data(M, N, nb):
    """
    Restrict the saved at-most-M-ratings data to the first `N` users.

    The first `N` rows of the index map are kept. The ids, movies and ratings
    arrays are cut after the largest rating position those rows reference.

    Returns:
        (ids, index map, movies, ratings) for the first `N` users.
    """
    ids, indexmap, movies, ratings, _ = load_leM_user_data(M = M, nb = nb)
    indexmap_N = indexmap[:N, :]
    # one past the largest rating position referenced by the kept rows
    n_kept = np.max(indexmap_N) + 1
    return ids[:n_kept], indexmap_N, movies[:n_kept], ratings[:n_kept]

def get_leM_N_data(M, N, nb):
    """
    Return (ids, index map, movies, ratings) for the first `N` users with at
    most `M` ratings, generating and saving any missing intermediate data.

    Any exception from a step is taken to mean that the data the step reads
    is missing, so the fallbacks run in this order:
        1. read the saved at-most-M data,
        2. rebuild it from the saved user data,
        3. rebuild the user data from the ratings CSV,
        4. rebuild the genome data from the CSVs first, then the user data.
    The default CSV locations are used when regenerating.
    """
    try:
        print("Trying to retrieve the save data..")
        return generate_leM_N_user_data(M, N, nb)
    except Exception:
        # assuming the error is because of not finding the leM data; let's generate it, save it, and retry
        try:
            print(f"Data for users with at most ratings {M} not found.")
            print(f"Trying to generate the data with users less than {M} ratings.")
            leM_data = generate_leM_user_data(M = M, nb = nb)
            print(f"Trying to save the data with users less than {M} ratings.")
            save_leM_user_data(*leM_data, M, nb = nb)
            print(f"Data saved for less than {M} ratings.")
        except Exception:
            # assuming the error is because of not finding the user data; let's generate it, save it, and retry
            try:
                print("Could not find the data for users. Trying to generate it from original files.")
                orig_data = generate_users_data(nb = nb)
            except Exception:
                # assuming the error is because of not finding the genome data; let's generate it, save it, and retry
                genome_data = generate_genome_data(nb = nb)
                save_genome_data(*genome_data, nb = nb)
                orig_data = generate_users_data(nb = nb)
            # Save on both paths: generate_leM_user_data reads the user data
            # back from disk, so it must be written before the next call.
            save_users_data(*orig_data, nb = nb)
            leM_data = generate_leM_user_data(M = M, nb = nb)
            save_leM_user_data(*leM_data, M, nb = nb)
        print("Hopefully done generating the data.")
        result = generate_leM_N_user_data(M, N, nb = nb)
        print("Data retrieved.")
        return result

def get_train_test_split_leM_N_data(M, N, split, nb = False, force_generate = False):
    """
    Return the train/test index maps restricted to the first `N` users.

    The maps, movies and ratings come from `get_leM_train_test_data`, whose
    first returned element is treated as the test map and second as the
    train map.

    Returns:
        (train index map, test index map, movies, ratings); only the index
        maps are cut down to `N` rows.
    """
    test_map, train_map, movies_leM, ratings_leM = get_leM_train_test_data(
        M, split, nb, force_generate
    )
    first_n = slice(None, N)
    return (train_map[first_n, :], test_map[first_n, :], movies_leM, ratings_leM)

def get_leM_train_test_data(M, split, nb, force_generate):
    """
    Return the train/test split for users with at most `M` ratings.

    With `force_generate` the split is always regenerated and saved.
    Otherwise the saved split is loaded, and it is generated and saved only
    when loading raises FileNotFoundError. Any other error propagates.

    NOTE(review): `generate_train_test_split_leM_data` returns the test map
    first and the train map second, but the locals here (and therefore the
    keys written by `save_leM_train_test_data`) name them the other way
    round. The returned order is (first, second, movies, ratings) as
    produced by the generator -- confirm before relying on the key names.
    """
    def _generate_and_save():
        # build a fresh split and persist it before handing it back
        train_id_map, test_id_map, movies, ratings = generate_train_test_split_leM_data(M, split, nb)
        save_leM_train_test_data(train_id_map, test_id_map, movies, ratings, M, split, nb)
        return train_id_map, test_id_map, movies, ratings

    if force_generate:
        return _generate_and_save()

    try:
        # the archive also stores M and split; they are not needed here
        train_id_map, test_id_map, movies, ratings, _, _ = load_leM_train_test_data(M, split, nb)
    except FileNotFoundError:
        # probably could not find the file to load from.
        return _generate_and_save()
    return train_id_map, test_id_map, movies, ratings

def generate_train_test_split_leM_data(M, split, nb= False):
    """
    Split every user's ratings into a test part and a train part.

    The saved at-most-M index map is loaded. For each user (row), a random
    subset of `max(int(split * n_ratings), 1)` rating positions is moved to
    the test map; those slots are overwritten with -1 in the train map.

    Returns:
        (test index map, train index map, movies, ratings) -- test first.
    """
    # TODO: Replace the load_leM_user_data with a get_leM_user_data
    _, users_leM_indexmap, movies_leM, ratings_leM, _ = load_leM_user_data(M = M, nb = nb)

    def _split_rows(idx_map, split):
        # seed the global NumPy RNG with the number of valid entries so the
        # split is reproducible for a given index map
        onp.random.seed((idx_map > -1).sum())

        # plain NumPy is used because rows are mutated in place
        train_map = onp.array(idx_map)
        test_map = onp.full_like(train_map, -1)

        for r in tqdm(range(train_map.shape[0])):
            row = train_map[r]

            n_valid = (row > -1).sum()
            n_held_out = max(int(split * n_valid), 1)

            order = onp.arange(n_valid)
            onp.random.shuffle(order)
            picked = order[:n_held_out]

            # copy the picked positions to the test row before blanking them
            test_map[r, :n_held_out] = row[picked]
            train_map[r, picked] = -1

        return np.array(test_map), np.array(train_map)

    test_users_leM_indexmap, train_users_leM_indexmap = _split_rows(
        users_leM_indexmap, split
    )
    return (test_users_leM_indexmap,
            train_users_leM_indexmap,
            movies_leM,
            ratings_leM)

def save_leM_train_test_data(train_id_map, test_id_map, movies, ratings, M, split, nb = False):
    """
    Save a train/test split to one .npz archive.

    The archive path is
    `utils.get_users_dir(nb=nb) + utils.get_leM_train_test_fname(M=M, split=split)`
    and the keys are "train_id_map", "test_id_map", "movies", "ratings",
    "M" and "split", in that order.
    """
    target_dir = utils.get_users_dir(nb = nb)
    target_name = utils.get_leM_train_test_fname(M = M, split = split)
    payload = {
        "train_id_map": train_id_map,
        "test_id_map": test_id_map,
        "movies": movies,
        "ratings": ratings,
        "M": M,
        "split": split,
    }
    np.savez(target_dir + target_name, **payload)

def load_leM_train_test_data(M, split, nb = False):
    """
    Load a saved train/test split via `utils.load_npz_arrays`.

    Reads the archive at
    `utils.get_users_dir(nb=nb) + utils.get_leM_train_test_fname(M=M, split=split)`.
    """
    archive_path = (
        utils.get_users_dir(nb = nb)
        + utils.get_leM_train_test_fname(M = M, split = split)
    )
    return utils.load_npz_arrays(archive_path)

def check_leM_train_test_data(M, split, nb):
    """
    Assert that `get_leM_train_test_data` agrees with the archive on disk.

    `get_leM_train_test_data` requires a `force_generate` argument; False is
    passed so that an existing archive is reused rather than regenerated.
    """
    obtained = get_leM_train_test_data(M, split, nb, force_generate = False)
    # the archive additionally stores M and split as its last two entries
    *stored, _, _ = load_leM_train_test_data(M, split, nb)
    assert utils.array_equal_list(list(obtained), stored)
    

# def get_leM_N_toydata(M, N, D, nb):
#     (users_leM_ids_N, users_leM_indexmap_N, movies_leM_N, ratings_leM_N) = get_leM_N_data(M, N, D, nb)
#     genome = get_PCAed_toygenome(D = D, nb = nb)
#     p = model.RSModel(
#                         users = users_leM_ids_N, 
#                         users_metadata = users_leM_indexmap_N,
#                         movies = movies_leM_N, 
#                         ratings = ratings_leM_N, 
#                         genome = genome, 
#                         )
#     _, _, toy_ratings = p.forward_sample(jax.random.PRNGKey(0), None)

#     return (users_leM_ids_N, users_leM_indexmap_N, movies_leM_N, toy_ratings)


def get_PCAed_genome(D, nb):
    """
    Return the genome matrix reduced to `D` principal components.

    Tries, in order:
        1. loading a saved PCA-reduced genome,
        2. computing it from the saved genome data,
        3. regenerating and saving the genome data from the CSVs (default
           locations), then computing the reduction.
    A freshly computed reduction is returned but not saved by this function.

    Only `Exception` subclasses trigger a fallback, so KeyboardInterrupt and
    SystemExit propagate instead of being swallowed.
    """
    try:
        return load_PCAed_genome_data(dim = D, nb = nb)
    except Exception:
        # assuming the error is because of not finding the PCAed genome data; let's generate it
        try:
            return generate_PCAed_genome_data(dim = D, nb = nb)
        except Exception:
            # assuming the error is because of not finding the genome data; let's generate it, save it, and retry
            genome_data = generate_genome_data(nb = nb)
            save_genome_data(*genome_data, nb = nb)
            return generate_PCAed_genome_data(dim = D, nb = nb)

def get_PCAed_toygenome(D, nb):
    """Return the synthetic genome with `D` columns (see `load_PCAed_toygenome_data`)."""
    toy_genome = load_PCAed_toygenome_data(dim = D, nb = nb)
    return toy_genome

def load_PCAed_toygenome_data(dim, nb):
    """
    Return a reproducible synthetic genome of shape (40000, dim).

    The global NumPy RNG is seeded with 0 and standard-normal samples are
    drawn, so every call yields the same values. `nb` is not used.
    """
    onp.random.seed(0)
    samples = onp.random.normal(size = (40000, dim))
    return np.array(samples)


# DIR_PATH = "data/datasets/movielens25M"
# SCORE_FNAME = "data/datasets/movielens25M"
# DIR_PATH = "data/datasets/movielens25M"


def get_data(config_dict, nb):
    """
    Assemble the dataset dictionary described by `config_dict`.

    Keys read from `config_dict`:
        'M_data' (default 1000), 'N_leaves' (required),
        'RS_data_split' (default 0.1), 'RS_force_generate' (default False),
        'D_data' (required).

    Returns:
        dict with keys "train_ids", "test_ids", "movies", "ratings" and
        "genome".
    """
    print("Getting user data ...")
    train_ids, test_ids, movies, ratings = get_train_test_split_leM_N_data(
        M = config_dict.get('M_data', 1000),
        N = config_dict['N_leaves'],
        split = config_dict.get("RS_data_split", 0.1),
        nb = nb,
        force_generate = config_dict.get('RS_force_generate', False),
    )

    print("Getting genome data ...")
    genome = get_PCAed_genome(D = config_dict['D_data'], nb = nb)

    return {
        "train_ids": train_ids,
        "test_ids": test_ids,
        "movies": movies,
        "ratings": ratings,
        "genome": genome,
    }
